import torch
import torch.nn as nn
import torch.nn.functional as F

from CtsConv import *

class CtsConvLSTMCell(nn.Module):
    """LSTM-style recurrent cell whose gates are driven by ``CtsConv`` layers.

    Each of the four gates (input ``i``, forget ``f``, candidate ``g`` and
    output ``o``) is computed as::

        cts_conv_<gate>(field, center, feats, field_mask)
            + dense_<gate>(feats)
            [+ bias_<gate>]

    where ``feats`` is ``field_feat`` concatenated with the hidden state
    ``h`` along the last axis.  The candidate gate's extra bias is stored
    under the name ``bias_c``.

    NOTE(review): the ``i``, ``f`` and ``o`` gates are squashed with ReLU
    here, not with the sigmoid used by a textbook LSTM -- confirm that this
    is intentional before changing it, since it alters trained behaviour.

    Args:
        in_channels: Size of the last axis of ``field_feat``.
        out_channels: Size of the last axis of the hidden / cell state.
        kernel_sizes: Forwarded unchanged to every ``CtsConv``.
        radius: Forwarded unchanged to every ``CtsConv``.
        normalize_attention: Forwarded unchanged to every ``CtsConv``.
        bias: When true, the dense layers carry their own bias *and* an
            additional learnable per-gate bias vector is added on top.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_sizes,
        radius,
        normalize_attention=False,
        bias=True
    ):
        super().__init__()

        # Hyper-parameters are stored first: the layer factories below
        # read them back from ``self``.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_sizes = kernel_sizes
        self.radius = radius
        self.normalize_attention = normalize_attention
        self.bias = bias

        # The creation order of the sub-layers (f, i, g, o for the convs;
        # i, f, g, o for the dense layers; i, f, o, c for the biases) is
        # kept fixed so that parameter ordering and seeded initialisation
        # stay reproducible.
        for gate in ("f", "i", "g", "o"):
            layer_name = "cts_conv_" + gate
            setattr(self, layer_name, self.__get_ctsconv_layer(layer_name))

        fan_in = in_channels + out_channels
        for gate in ("i", "f", "g", "o"):
            setattr(self, "dense_" + gate,
                    nn.Linear(fan_in, out_channels, bias=bias))

        if bias:
            # The helper registers the parameter on the module itself, so
            # no further assignment is needed.
            for bias_name in ("bias_i", "bias_f", "bias_o", "bias_c"):
                self.__get_bias(bias_name)

    def forward(self, inputs, states):
        """Advance the cell by one step.

        Args:
            inputs: 4-tuple ``(field, center, field_feat, field_mask)``.
                ``field``, ``center`` and ``field_mask`` are handed to the
                ``CtsConv`` layers untouched.
            states: 2-tuple ``(h, c)`` holding the previous hidden and cell
                state.  ``h`` must be concatenable with ``field_feat`` along
                the last axis.

        Returns:
            Tuple ``(h_out, c_out)`` with the new hidden and cell state.
        """
        field, center, field_feat, field_mask = inputs
        h, c = states

        feats = torch.cat([field_feat, h], dim=-1)
        conv_args = (field, center, feats, field_mask)

        # ----- cell state update
        in_gate = F.relu(self._pre_activation("i", "bias_i", conv_args, feats))
        forget_gate = F.relu(self._pre_activation("f", "bias_f", conv_args, feats))
        candidate = torch.tanh(self._pre_activation("g", "bias_c", conv_args, feats))
        c_out = c * forget_gate + in_gate * candidate

        # ----- hidden state update
        out_gate = F.relu(self._pre_activation("o", "bias_o", conv_args, feats))
        h_out = out_gate * torch.tanh(c_out)

        return (h_out, c_out)

    def _pre_activation(self, gate, bias_name, conv_args, feats):
        """Return conv + dense (+ extra bias) for one gate, before activation."""
        conv = getattr(self, "cts_conv_" + gate)
        dense = getattr(self, "dense_" + gate)
        total = conv(*conv_args) + dense(feats)
        if self.bias:
            total = total + getattr(self, bias_name)
        return total

    def __get_ctsconv_layer(self, layer_name):
        """Build one ``CtsConv`` mapping in+out channels to out channels."""
        conv_kwargs = dict(
            in_channels=self.in_channels + self.out_channels,
            out_channels=self.out_channels,
            kernel_sizes=self.kernel_sizes,
            radius=self.radius,
            normalize_attention=self.normalize_attention,
            layer_name=layer_name,
        )
        return CtsConv(**conv_kwargs)

    def __get_bias(self, bias_name):
        """Create, register and return a bias vector of length ``out_channels``.

        Values are drawn uniformly from ``[-0.5, 0.5)`` and then divided by
        ``out_channels``.
        """
        width = self.out_channels
        bias_vec = nn.Parameter((torch.rand(width) - 0.5) / width)
        self.register_parameter(bias_name, bias_vec)
        return bias_vec
        
        